{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vs3a5tGVAWGI"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "HYfsarcYBJQp"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aOpqCFEyBQDd"
      },
      "source": [
        "# Uncertainty-aware Deep Language Learning with BERT-SNGP"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6MlSYP6cBT61"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/text/tutorials/uncertainty_quantification_with_sngp_bert\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/text/blob/master/docs/tutorials/uncertainty_quantification_with_sngp_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/text/blob/master/docs/tutorials/uncertainty_quantification_with_sngp_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/text/docs/tutorials/uncertainty_quantification_with_sngp_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" /\u003eSee TF Hub model\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-IM5IzM26GBh"
      },
      "source": [
        "In the [SNGP tutorial](https://www.tensorflow.org/tutorials/understanding/sngp), you learned how to build SNGP model on top of a deep residual network to improve its ability to quantify its uncertainty. In this tutorial, you will apply SNGP to a natural language understanding (NLU) task by building it on top of a deep BERT encoder to improve deep NLU model's ability in detecting out-of-scope queries. \n",
        "\n",
        "Specifically, you will:\n",
        "* Build BERT-SNGP, a SNGP-augmented [BERT](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2) model.\n",
        "* Load the [CLINC Out-of-scope (OOS)](https://www.tensorflow.org/datasets/catalog/clinc_oos) intent detection dataset.\n",
        "* Train the BERT-SNGP model.\n",
        "* Evaluate the BERT-SNGP model's performance in uncertainty calibration and out-of-domain detection.\n",
        "\n",
        "Beyond CLINC OOS, the SNGP model has been applied to large-scale datasets such as [Jigsaw toxicity detection](https://www.tensorflow.org/datasets/catalog/wikipedia_toxicity_subtypes), and to the image datasets such as [CIFAR-100](https://www.tensorflow.org/datasets/catalog/cifar100) and [ImageNet](https://www.tensorflow.org/datasets/catalog/imagenet2012). \n",
        "For benchmark results of SNGP and other uncertainty methods, as well as high-quality implementation with end-to-end training / evaluation scripts, you can check out the [Uncertainty Baselines](https://github.com/google/uncertainty-baselines) benchmark."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-bsids4eAYYI"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l2dCK-rbYXsb"
      },
      "outputs": [],
      "source": [
        "!pip uninstall -y tensorflow tf-text"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MmlftNekWmKR"
      },
      "outputs": [],
      "source": [
        "!pip install \"tensorflow-text==2.11.*\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3sgnLBKk7iuR"
      },
      "outputs": [],
      "source": [
        "!pip install -U tf-models-official==2.11.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M42dnVSk7dVy"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import sklearn.metrics\n",
        "import sklearn.calibration\n",
        "\n",
        "import tensorflow_hub as hub\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "import official.nlp.modeling.layers as layers\n",
        "import official.nlp.optimization as optimization"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4TiolAXow5Rs"
      },
      "source": [
        "This tutorial needs the GPU to run efficiently. Check if the GPU is available. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "18dxUFtEBeIR"
      },
      "outputs": [],
      "source": [
        "tf.__version__"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9enQL-rZxGkP"
      },
      "outputs": [],
      "source": [
        "gpus = tf.config.list_physical_devices('GPU')\n",
        "gpus"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZY_xLQnS-6ar"
      },
      "outputs": [],
      "source": [
        "assert gpus, \"\"\"\n",
        "  No GPU(s) found! This tutorial will take many hours to run without a GPU.\n",
        "\n",
        "  You may hit this error if the installed tensorflow package is not\n",
        "  compatible with the CUDA and CUDNN versions.\"\"\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cnRQfguq6GZj"
      },
      "source": [
        "First implement a standard BERT classifier following the [classify text with BERT](https://www.tensorflow.org/tutorials/text/classify_text_with_bert) tutorial. We will use the [BERT-base](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3) encoder, and the built-in [`ClassificationHead`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/cls_head.py) as the classifier."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bNBEGs7s6NHB"
      },
      "outputs": [],
      "source": [
        "#@title Standard BERT model\n",
        "\n",
        "PREPROCESS_HANDLE = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'\n",
        "MODEL_HANDLE = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3'\n",
        "\n",
        "class BertClassifier(tf.keras.Model):\n",
        "  def __init__(self, \n",
        "               num_classes=150, inner_dim=768, dropout_rate=0.1,\n",
        "               **classifier_kwargs):\n",
        "    \n",
        "    super().__init__()\n",
        "    self.classifier_kwargs = classifier_kwargs\n",
        "\n",
        "    # Initiate the BERT encoder components.\n",
        "    self.bert_preprocessor = hub.KerasLayer(PREPROCESS_HANDLE, name='preprocessing')\n",
        "    self.bert_hidden_layer = hub.KerasLayer(MODEL_HANDLE, trainable=True, name='bert_encoder')\n",
        "\n",
        "    # Defines the encoder and classification layers.\n",
        "    self.bert_encoder = self.make_bert_encoder()\n",
        "    self.classifier = self.make_classification_head(num_classes, inner_dim, dropout_rate)\n",
        "\n",
        "  def make_bert_encoder(self):\n",
        "    text_inputs = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')\n",
        "    encoder_inputs = self.bert_preprocessor(text_inputs)\n",
        "    encoder_outputs = self.bert_hidden_layer(encoder_inputs)\n",
        "    return tf.keras.Model(text_inputs, encoder_outputs)\n",
        "\n",
        "  def make_classification_head(self, num_classes, inner_dim, dropout_rate):\n",
        "    return layers.ClassificationHead(\n",
        "        num_classes=num_classes, \n",
        "        inner_dim=inner_dim,\n",
        "        dropout_rate=dropout_rate,\n",
        "        **self.classifier_kwargs)\n",
        "\n",
        "  def call(self, inputs, **kwargs):\n",
        "    encoder_outputs = self.bert_encoder(inputs)\n",
        "    classifier_inputs = encoder_outputs['sequence_output']\n",
        "    return self.classifier(classifier_inputs, **kwargs)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SbhbNbKk6WNR"
      },
      "source": [
        "### Build SNGP model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p7YakN0V6Oif"
      },
      "source": [
        "To implement a BERT-SNGP model, you only need to replace the `ClassificationHead` with the built-in [`GaussianProcessClassificationHead`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/cls_head.py). Spectral normalization is already pre-packaged into this classification head. Like in the [SNGP tutorial](https://www.tensorflow.org/tutorials/uncertainty/sngp), add a covariance reset callback to the model, so the model automatically reset the covariance estimator at the beginning of a new epoch to avoid counting the same data twice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QCaJy85y8WeE"
      },
      "outputs": [],
      "source": [
        "class ResetCovarianceCallback(tf.keras.callbacks.Callback):\n",
        "\n",
        "  def on_epoch_begin(self, epoch, logs=None):\n",
        "    \"\"\"Resets covariance matrix at the beginning of the epoch.\"\"\"\n",
        "    if epoch \u003e 0:\n",
        "      self.model.classifier.reset_covariance_matrix()"
      ]
    },
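    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The reset matters because, with `gp_cov_momentum=-1`, the covariance estimator is updated as a plain running sum over the batches it sees rather than as a moving average. The cell below is a minimal plain-Python sketch of that bookkeeping, not the library's implementation: `RIDGE`, `init_precision`, `accumulate`, and the toy features are hypothetical, and the update is simplified to a ridge term plus a sum of outer products. Without a reset, a second pass over the same data counts every example twice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "RIDGE = 1.0  # Hypothetical ridge penalty that initializes the precision matrix.\n",
        "\n",
        "def init_precision(dim, ridge=RIDGE):\n",
        "  \"\"\"Returns ridge * identity as a nested list.\"\"\"\n",
        "  return [[ridge if i == j else 0.0 for j in range(dim)] for i in range(dim)]\n",
        "\n",
        "def accumulate(precision, features):\n",
        "  \"\"\"Adds the outer product of every feature vector to the precision matrix.\"\"\"\n",
        "  for phi in features:\n",
        "    for i, phi_i in enumerate(phi):\n",
        "      for j, phi_j in enumerate(phi):\n",
        "        precision[i][j] += phi_i * phi_j\n",
        "  return precision\n",
        "\n",
        "toy_features = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]  # Illustrative values only.\n",
        "\n",
        "# One pass over the data: the exact statistics.\n",
        "one_epoch = accumulate(init_precision(2), toy_features)\n",
        "\n",
        "# Two passes without a reset: every example is counted twice.\n",
        "no_reset = accumulate(accumulate(init_precision(2), toy_features), toy_features)\n",
        "\n",
        "# Two passes with a reset at the start of the second epoch.\n",
        "with_reset = accumulate(init_precision(2), toy_features)  # Epoch 1.\n",
        "with_reset = accumulate(init_precision(2), toy_features)  # Reset, then epoch 2.\n",
        "\n",
        "print('one epoch: ', one_epoch)\n",
        "print('no reset:  ', no_reset)\n",
        "print('with reset:', with_reset)"
      ]
    },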
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YoHgOuiZ6Q4y"
      },
      "outputs": [],
      "source": [
        "class SNGPBertClassifier(BertClassifier):\n",
        "\n",
        "  def make_classification_head(self, num_classes, inner_dim, dropout_rate):\n",
        "    return layers.GaussianProcessClassificationHead(\n",
        "        num_classes=num_classes, \n",
        "        inner_dim=inner_dim,\n",
        "        dropout_rate=dropout_rate,\n",
        "        gp_cov_momentum=-1,\n",
        "        temperature=30.,\n",
        "        **self.classifier_kwargs)\n",
        "\n",
        "  def fit(self, *args, **kwargs):\n",
        "    \"\"\"Adds ResetCovarianceCallback to model callbacks.\"\"\"\n",
        "    kwargs['callbacks'] = list(kwargs.get('callbacks', []))\n",
        "    kwargs['callbacks'].append(ResetCovarianceCallback())\n",
        "\n",
        "    return super().fit(*args, **kwargs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UOj5YWTt6dCe"
      },
      "source": [
        "Note: The `GaussianProcessClassificationHead` takes a new argument `temperature`. It corresponds to the $\\lambda$ parameter in the __mean-field approximation__ introduced in the [SNGP tutorial](https://www.tensorflow.org/tutorials/understanding/sngp). In practice, this value is usually treated as a hyperparameter, and is finetuned to optimize the model's calibration performance."
      ]
    },
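    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of what `temperature` does, the mean-field approximation divides each logit by $\\sqrt{1 + \\lambda \\sigma^2(x)}$, where $\\sigma^2(x)$ is the predictive variance. The cell below is a plain-Python sketch under that assumption: `mean_field_adjust`, the logits, and the variance values are made up for illustration and are not produced by the model. A larger variance or a larger $\\lambda$ shrinks the logits and flattens the softmax."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def softmax(logits):\n",
        "  \"\"\"Numerically stable softmax over a list of floats.\"\"\"\n",
        "  shift = max(logits)\n",
        "  exps = [math.exp(v - shift) for v in logits]\n",
        "  total = sum(exps)\n",
        "  return [e / total for e in exps]\n",
        "\n",
        "def mean_field_adjust(logits, variance, temperature):\n",
        "  \"\"\"Scales logits by 1 / sqrt(1 + temperature * variance).\"\"\"\n",
        "  scale = math.sqrt(1.0 + temperature * variance)\n",
        "  return [v / scale for v in logits]\n",
        "\n",
        "toy_logits = [4.0, 1.0, 0.0]  # Illustrative values only.\n",
        "\n",
        "for variance in (0.0, 0.1, 1.0):\n",
        "  adjusted = mean_field_adjust(toy_logits, variance, temperature=30.)\n",
        "  print(f'variance={variance:.1f}  max prob={max(softmax(adjusted)):.3f}')"
      ]
    },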
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qdU90uDT6hFq"
      },
      "source": [
        "### Load CLINC OOS dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AnuNeyHw6kH7"
      },
      "source": [
        "Now load the [CLINC OOS](https://www.tensorflow.org/datasets/catalog/clinc_oos) intent detection dataset. This dataset contains 15000 user's spoken queries collected over 150 intent classes, it also contains 1000 out-of-domain (OOD) sentences that are not covered by any of the known classes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mkMZN2iA6hhg"
      },
      "outputs": [],
      "source": [
        "(clinc_train, clinc_test, clinc_test_oos), ds_info = tfds.load(\n",
        "    'clinc_oos', split=['train', 'test', 'test_oos'], with_info=True, batch_size=-1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UJSL2nm8Bo02"
      },
      "source": [
        "Make the train and test data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cgkOOZOq6fQL"
      },
      "outputs": [],
      "source": [
        "train_examples = clinc_train['text']\n",
        "train_labels = clinc_train['intent']\n",
        "\n",
        "# Makes the in-domain (IND) evaluation data.\n",
        "ind_eval_data = (clinc_test['text'], clinc_test['intent'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Kw76f6caBq_E"
      },
      "source": [
        "Create a OOD evaluation dataset. For this, combine the in-domain test data `clinc_test` and the out-of-domain data `clinc_test_oos`. We will also assign label 0 to the in-domain examples, and label 1 to the out-of-domain examples. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uVFuzecR64FJ"
      },
      "outputs": [],
      "source": [
        "test_data_size = ds_info.splits['test'].num_examples\n",
        "oos_data_size = ds_info.splits['test_oos'].num_examples\n",
        "\n",
        "# Combines the in-domain and out-of-domain test examples.\n",
        "oos_texts = tf.concat([clinc_test['text'], clinc_test_oos['text']], axis=0)\n",
        "oos_labels = tf.constant([0] * test_data_size + [1] * oos_data_size)\n",
        "\n",
        "# Converts into a TF dataset.\n",
        "ood_eval_dataset = tf.data.Dataset.from_tensor_slices(\n",
        "    {\"text\": oos_texts, \"label\": oos_labels})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZcHwfwfU6qCE"
      },
      "source": [
        "### Train and evaluate"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_VTY6KYc6sBB"
      },
      "source": [
        "First set up the basic training configurations."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_-uUkUtk6qWC"
      },
      "outputs": [],
      "source": [
        "TRAIN_EPOCHS = 3\n",
        "TRAIN_BATCH_SIZE = 32\n",
        "EVAL_BATCH_SIZE = 256"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tiEjMdFV6wXQ"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "\n",
        "def bert_optimizer(learning_rate, \n",
        "                   batch_size=TRAIN_BATCH_SIZE, epochs=TRAIN_EPOCHS, \n",
        "                   warmup_rate=0.1):\n",
        "  \"\"\"Creates an AdamWeightDecay optimizer with learning rate schedule.\"\"\"\n",
        "  train_data_size = ds_info.splits['train'].num_examples\n",
        "  \n",
        "  steps_per_epoch = int(train_data_size / batch_size)\n",
        "  num_train_steps = steps_per_epoch * epochs\n",
        "  num_warmup_steps = int(warmup_rate * num_train_steps)  \n",
        "\n",
        "  # Creates learning schedule.\n",
        "  lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(\n",
        "      initial_learning_rate=learning_rate,\n",
        "      decay_steps=num_train_steps,\n",
        "      end_learning_rate=0.0)  \n",
        "  \n",
        "  return optimization.AdamWeightDecay(\n",
        "      learning_rate=lr_schedule,\n",
        "      weight_decay_rate=0.01,\n",
        "      epsilon=1e-6,\n",
        "      exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])"
      ]
    },
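    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the schedule arithmetic concrete, the cell below recomputes the step counts in plain Python and evaluates a linear decay by hand. It is a sketch under stated assumptions: `train_data_size=15000` is taken from the dataset description above rather than read from `ds_info`, `schedule_steps` and `linear_decay` are hypothetical helpers, and the decay assumes that `PolynomialDecay` keeps its default `power=1.0`. Note that `bert_optimizer` computes `num_warmup_steps` but passes only the decay schedule to the optimizer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def schedule_steps(train_data_size, batch_size, epochs, warmup_rate=0.1):\n",
        "  \"\"\"Mirrors the step arithmetic used in `bert_optimizer`.\"\"\"\n",
        "  steps_per_epoch = int(train_data_size / batch_size)\n",
        "  num_train_steps = steps_per_epoch * epochs\n",
        "  num_warmup_steps = int(warmup_rate * num_train_steps)\n",
        "  return steps_per_epoch, num_train_steps, num_warmup_steps\n",
        "\n",
        "def linear_decay(step, initial_learning_rate, decay_steps, end_learning_rate=0.0):\n",
        "  \"\"\"Polynomial decay with power 1: a straight line down to the end rate.\"\"\"\n",
        "  step = min(step, decay_steps)\n",
        "  fraction_left = 1.0 - step / decay_steps\n",
        "  return (initial_learning_rate - end_learning_rate) * fraction_left + end_learning_rate\n",
        "\n",
        "# Assumed training-set size; batch size and epochs match the configuration above.\n",
        "steps_per_epoch, num_train_steps, num_warmup_steps = schedule_steps(\n",
        "    train_data_size=15000, batch_size=32, epochs=3)\n",
        "print(steps_per_epoch, num_train_steps, num_warmup_steps)\n",
        "\n",
        "for step in (0, num_train_steps // 2, num_train_steps):\n",
        "  print(step, linear_decay(step, 1e-4, num_train_steps))"
      ]
    },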
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KX_Hzl3l6w-H"
      },
      "outputs": [],
      "source": [
        "optimizer = bert_optimizer(learning_rate=1e-4)\n",
        "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
        "metrics = tf.metrics.SparseCategoricalAccuracy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ptn9Cupe6z7o"
      },
      "outputs": [],
      "source": [
        "fit_configs = dict(batch_size=TRAIN_BATCH_SIZE,\n",
        "                   epochs=TRAIN_EPOCHS,\n",
        "                   validation_batch_size=EVAL_BATCH_SIZE, \n",
        "                   validation_data=ind_eval_data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0ZK5PBwW61jd"
      },
      "outputs": [],
      "source": [
        "sngp_model = SNGPBertClassifier()\n",
        "sngp_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)\n",
        "sngp_model.fit(train_examples, train_labels, **fit_configs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cpDsgTYx63tO"
      },
      "source": [
        "### Evaluate OOD performance"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d5NGVe7L67bB"
      },
      "source": [
        "Evaluate how well the model can detect the unfamiliar out-of-domain queries. For rigorous evaluation, use the OOD evaluation dataset `ood_eval_dataset` built earlier."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "yyLgt_lL7APo"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "\n",
        "def oos_predict(model, ood_eval_dataset, **model_kwargs):\n",
        "  oos_labels = []\n",
        "  oos_probs = []\n",
        "\n",
        "  ood_eval_dataset = ood_eval_dataset.batch(EVAL_BATCH_SIZE)\n",
        "  for oos_batch in ood_eval_dataset:\n",
        "    oos_text_batch = oos_batch[\"text\"]\n",
        "    oos_label_batch = oos_batch[\"label\"] \n",
        "\n",
        "    pred_logits = model(oos_text_batch, **model_kwargs)\n",
        "    pred_probs_all = tf.nn.softmax(pred_logits, axis=-1)\n",
        "    pred_probs = tf.reduce_max(pred_probs_all, axis=-1)\n",
        "\n",
        "    oos_labels.append(oos_label_batch)\n",
        "    oos_probs.append(pred_probs)\n",
        "\n",
        "  oos_probs = tf.concat(oos_probs, axis=0)\n",
        "  oos_labels = tf.concat(oos_labels, axis=0) \n",
        "\n",
        "  return oos_probs, oos_labels"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dmc2tVXs6_uo"
      },
      "source": [
        "Computes the OOD probabilities as $1 - p(x)$, where $p(x)=softmax(logit(x))$ is the predictive probability."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_9aFVVDO7C7o"
      },
      "outputs": [],
      "source": [
        "sngp_probs, ood_labels = oos_predict(sngp_model, ood_eval_dataset)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_PC0wwZp7GJD"
      },
      "outputs": [],
      "source": [
        "ood_probs = 1 - sngp_probs"
      ]
    },
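    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cells above reduce each prediction to a single uncertainty score: one minus the largest softmax probability. The plain-Python sketch below works through that computation on made-up logits; `ood_score` is a hypothetical helper, not part of the tutorial's model code. A confident prediction gets a score near 0, while a flat distribution over $K$ classes gets the largest possible score, $1 - 1/K$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def ood_score(logits):\n",
        "  \"\"\"Returns 1 - max softmax probability for one example.\"\"\"\n",
        "  shift = max(logits)  # Subtracting the max keeps exp() numerically stable.\n",
        "  exps = [math.exp(v - shift) for v in logits]\n",
        "  return 1.0 - max(exps) / sum(exps)\n",
        "\n",
        "confident_logits = [10.0, 0.0, 0.0]  # One class clearly dominates.\n",
        "uncertain_logits = [0.0, 0.0, 0.0]   # All classes are equally likely.\n",
        "\n",
        "print(f'confident: {ood_score(confident_logits):.4f}')\n",
        "print(f'uncertain: {ood_score(uncertain_logits):.4f}')"
      ]
    },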
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AsandMTX7HjX"
      },
      "source": [
        "Now evaluate how well the model's uncertainty score `ood_probs` predicts the out-of-domain label. First compute the Area under precision-recall curve (AUPRC) for OOD probability v.s. OOD detection accuracy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0u5Wx8AP7Mdx"
      },
      "outputs": [],
      "source": [
        "precision, recall, _ = sklearn.metrics.precision_recall_curve(ood_labels, ood_probs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "axcctOsh7N5A"
      },
      "outputs": [],
      "source": [
        "auprc = sklearn.metrics.auc(recall, precision)\n",
        "print(f'SNGP AUPRC: {auprc:.4f}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U_GEqxq-7Q1Y"
      },
      "source": [
        "This matches the SNGP performance reported at the CLINC OOS benchmark under the [Uncertainty Baselines](https://github.com/google/uncertainty-baselines)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8H4vYcyd7Ux2"
      },
      "source": [
        "Next, examine the model's quality in [uncertainty calibration](https://scikit-learn.org/stable/modules/calibration.html), i.e., whether the model's predictive probability corresponds to its predictive accuracy. A well-calibrated model is considered trust-worthy, since, for example, its predictive probability $p(x)=0.8$ means that the model is correct 80% of the time."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x5GxrSWJ7SYn"
      },
      "outputs": [],
      "source": [
        "prob_true, prob_pred = sklearn.calibration.calibration_curve(\n",
        "    ood_labels, ood_probs, n_bins=10, strategy='quantile')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ozzJM-D-7XVq"
      },
      "outputs": [],
      "source": [
        "plt.plot(prob_pred, prob_true)\n",
        "\n",
        "plt.plot([0., 1.], [0., 1.], c='k', linestyle=\"--\")\n",
        "plt.xlabel('Predictive Probability')\n",
        "plt.ylabel('Predictive Accuracy')\n",
        "plt.title('Calibration Plots, SNGP')\n",
        "\n",
        "plt.show()"
      ]
    },
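    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For intuition about what `calibration_curve` reports, the cell below bins scores by hand in plain Python. It is a simplified sketch, not scikit-learn's algorithm: `reliability_bins` is a hypothetical helper that sorts the examples and splits them into equal-sized groups, which only approximates the `'quantile'` strategy, and the labels and scores are made up. Each bin contributes one point to the curve: its mean score and its observed fraction of positive labels."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def reliability_bins(labels, scores, n_bins):\n",
        "  \"\"\"Returns (fraction_positive, mean_score) lists, one entry per bin.\"\"\"\n",
        "  # Sorts examples by score so that each bin covers a contiguous score range.\n",
        "  pairs = sorted(zip(scores, labels))\n",
        "  bin_size = len(pairs) // n_bins\n",
        "  frac_pos, mean_score = [], []\n",
        "  for b in range(n_bins):\n",
        "    # The last bin absorbs any remainder.\n",
        "    end = (b + 1) * bin_size if b \u003c n_bins - 1 else len(pairs)\n",
        "    chunk = pairs[b * bin_size:end]\n",
        "    mean_score.append(sum(s for s, _ in chunk) / len(chunk))\n",
        "    frac_pos.append(sum(y for _, y in chunk) / len(chunk))\n",
        "  return frac_pos, mean_score\n",
        "\n",
        "toy_labels = [0, 0, 0, 1, 0, 1, 1, 1]  # Illustrative values only.\n",
        "toy_scores = [0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9]\n",
        "\n",
        "frac_pos, mean_score = reliability_bins(toy_labels, toy_scores, n_bins=2)\n",
        "for p, s in zip(frac_pos, mean_score):\n",
        "  print(f'mean score={s:.2f}  fraction positive={p:.2f}')"
      ]
    },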
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "36M6HeHx7ZI4"
      },
      "source": [
        "## Resources and further reading"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xdFTpyaP0A-N"
      },
      "source": [
        "* See the [SNGP tutorial](https://www.tensorflow.org/tutorials/understanding/sngp) for a detailed walkthrough of implementing SNGP from scratch. \n",
        "* See [Uncertainty Baselines](https://github.com/google/uncertainty-baselines)  for the implementation of SNGP model (and many other uncertainty methods) on a wide variety of benchmark datasets (e.g., [CIFAR](https://www.tensorflow.org/datasets/catalog/cifar100), [ImageNet](https://www.tensorflow.org/datasets/catalog/imagenet2012), [Jigsaw toxicity detection](https://www.tensorflow.org/datasets/catalog/wikipedia_toxicity_subtypes), etc).\n",
        "* For a deeper understanding of the SNGP method, check out the paper [Simple and Principled Uncertainty Estimation with Deterministic Deep Learning via Distance Awareness](https://arxiv.org/abs/2006.10108).\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "uncertainty_quantification_with_sngp_bert.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1rpzuIuHNW4nnnj5mi1NhV9gjmiRy_QWB",
          "timestamp": 1622128463249
        },
        {
          "file_id": "/piper/depot/google3/third_party/tensorflow_text/g3doc/tutorials/uncertainty_quantification_with_sngp_bert.ipynb?workspaceId=markdaoust:no-nightly::citc",
          "timestamp": 1622127860630
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
